###############################################################################-
# Author: Pietryka
# Contact: matthew.pietryka@gmail.com
# Purpose: create Figure 5: Comparing peers and parents, by gender
# Notes:
###############################################################################-


#  1. Load Packages  =====================

library(dplyr)
library(tidyr)
library(purrr)
library(readr)
library(estimatr)
library(janitor)
library(ggplot2)
library(forcats)
library(stringr)
library(glue)

#  2. Load Data =====================

# Analysis data set. It was previously rebuilt by sourcing
# "Analysis 1 - Load Data - v5.R"; this script now reads the cached copy.
stacked_df <- read_rds(file.path("data-files", "stacked_df.rds"))

# Shared plot settings. format_plot(), se_plot_x_labs and se_plot_shapes are
# used below but not defined in this file -- presumably they live here.
source("Plot Preferences.r")




#  3. Functions for plots =====================

# Drop the "0" that immediately follows the first run of four digits,
# e.g. "turnout_20160_f" -> "turnout_2016_f". Strings with no such
# "dddd0" pattern are returned unchanged. Vectorised over `x`.
remove_trailing_zero <- function(x) {
  str_replace(string = x, pattern = "(\\d{4})0", replacement = "\\1")
}

# Fit an OLS regression with robust standard errors via estimatr::lm_robust()
# (using lm_robust's default standard-error type).
#
# a_formula: model formula, e.g. turnout_2016_f ~ turnout_2016_rm*is_woman.
# data:      data frame to fit on; defaults to the script-level `stacked_df`
#            (looked up when the function is called), so existing one-argument
#            calls such as map(forms, ols_robust) behave as before.
# Returns the lm_robust model object.
ols_robust <- function(a_formula, data = stacked_df) {
  lm_robust(a_formula, data = data)
}

# Predicted outcome, with confidence bounds, for each cell of a
# y ~ x * is_woman model, evaluated at x = 0/1 crossed with is_woman = 0/1.
#
# mod: an estimatr::lm_robust fit (as produced by ols_robust()) whose term
#      labels contain exactly one focal predictor ending in "_rm" or "_f",
#      plus is_woman and their interaction.
# Returns a 4-row tibble: x_value, is_woman, then the prediction columns
# yhat, lb, ub (renamed from predict()'s fit / lwr / upr).
predict_y <- function(mod) {
  # Term labels look like c("turnout_2016_rm", "is_woman",
  # "turnout_2016_rm:is_woman"); only the focal main effect ends in
  # "_rm" or "_f".
  x_name <- attr(mod$terms, "term.labels") |>
    keep(~ str_detect(.x, "_rm$|_f$"))

  if (length(x_name) != 1L) {
    found <- if (length(x_name) == 0L) "none" else paste(x_name, collapse = ", ")
    stop(
      "predict_y(): expected exactly one focal term ending in '_rm' or '_f'; found ",
      found,
      call. = FALSE
    )
  }

  # The four x/gender cells to predict for.
  grid_df <- tibble(
    x_value  = c(0, 1, 0, 1),
    is_woman = c(0, 0, 1, 1)
  )

  # predict() rebuilds the design matrix (intercept and interaction included)
  # from the model formula, so newdata only needs the raw predictors under
  # the names the formula uses.
  new_df <- grid_df |>
    rename(!!x_name := x_value)

  predict_df <- predict(mod, newdata = new_df, interval = "confidence") |>
    as.data.frame() |>
    as_tibble() |>
    rename(yhat = fit.fit, lb = fit.lwr, ub = fit.upr)

  bind_cols(grid_df, predict_df)
}


# 4. Generate combinations of variables =============

# Outcomes: student turnout in each election year.
y_vars <- c("turnout_2016_f", "turnout_2018_f", "turnout_2020_f")

# Focal predictors: roommate's 2016 turnout and parents' 2008-14 turnout.
x_vars <- c("turnout_2016_rm", "parents_turnout_mean0_f")

# One row per (outcome, predictor) pair. The integer indices y and x are
# kept because later steps drop them explicitly; the *_year columns hold the
# first number parsed from each variable name.
combos_df <- crossing(y = seq_along(y_vars), x = seq_along(x_vars)) |>
  mutate(
    y_name = y_vars[y],
    x_name = x_vars[x],
    y_year = parse_number(remove_trailing_zero(y_name)),
    x_year = parse_number(remove_trailing_zero(x_name))
  )




# 5. Calculate turnout by each X variable in each year ==============

# For each (outcome, predictor) pair: build the formula "y ~ x*is_woman",
# fit it with robust SEs, and store the per-cell predictions (list-columns).
ols_long_df <- combos_df |>
  mutate(
    xz_name    = paste(x_name, "is_woman", sep = "*"),
    the_form   = map2(
      y_name, xz_name,
      \(lhs, rhs) as.formula(paste(lhs, rhs, sep = "~"))
    ),
    ols_mod    = map(the_form, ols_robust),
    predict_df = map(ols_mod, predict_y)
  ) |>
  select(!c(x, y)) |>
  clean_names() |>
  rename(x = x_name, y = y_name)






# 6. Clean the names, years, etc. ==================

# Rewrite a model variable name as "<type>_<year>_<suffix>" so it can be
# split on "_" by separate():
#   "turnout_2016_f"          -> "student_2016_a"
#   "turnout_2016_rm"         -> "student_2016_b"
#   "parents_turnout_mean0_f" -> "parents_2008-14_a"
# The "_f" / "_rm" suffix rewrites are anchored to the end of the name so
# the same letters appearing elsewhere in a name are left untouched
# (previously "turnout_female_f" became "student_aemale_f").
clean_var_names <- function(x) {
  x %>%
    remove_trailing_zero() %>%
    str_replace("_f$", "_a") %>%
    str_replace("_rm$", "_b") %>%
    str_replace("mean0", "2008-14") %>%
    str_remove("_turnout") %>%
    str_replace("^turnout_", "student_")
}


# Flatten the per-model predictions and split each variable name into
# type / year / suffix columns. `lab` joins all six pieces (outcome first,
# then predictor) into one identifier; the split columns are kept.
ols_clean_df <- ols_long_df |>
  unnest(cols = predict_df) |>
  mutate(
    x = clean_var_names(x),
    y = clean_var_names(y)
  ) |>
  separate(x, into = c("type_x", "year_x", "roommate_x"), sep = "_") |>
  separate(y, into = c("type_y", "year_y", "roommate_y"), sep = "_") |>
  unite(
    col = "lab",
    type_y, year_y, roommate_y, type_x, year_x, roommate_x,
    remove = FALSE
  )

ols_clean_df







# 7. Make the plot =================================

# Add plotting helper columns:
#   x_value_unique - axis position per cell (roommate rows at 5/4,
#                    parent rows at 2/1, leaving a gap at 3)
#   gender_lab     - "Women"/"Men" facet label
#   type_x_lab     - facet-column factor, reversed so "Roommate" comes first
plot_df <- ols_clean_df |>
  mutate(
    type = ifelse(type_x == "parents", "Parents, 2008-14", "Roommate in 2016"),
    facet_lab = glue("Turnout in {year_y},\namong students whose ..."),
    x_value_unique = case_when(
      type_x == "student" & x_value == 0 ~ 5L,
      type_x == "student" & x_value == 1 ~ 4L,
      type_x == "parents" & x_value == 0 ~ 2L,
      type_x == "parents" & x_value == 1 ~ 1L
    ),
    gender_lab = ifelse(is_woman == 1, "Women", "Men"),
    type_x_lab = fct_rev(
      case_match(type_x, "student" ~ "Roommate", "parents" ~ "Parents")
    )
  )



# Point-and-interval plot of predicted turnout. Facet rows are gender x
# outcome year; facet columns are the predictor (roommate vs. parents).
# coord_flip() puts the cell positions (x_value_unique) on the vertical axis.
# format_plot(), se_plot_x_labs and se_plot_shapes are not defined in this
# file -- presumably they come from "Plot Preferences.r".
gender_se_plot <- ggplot(
  data = plot_df,
  mapping = aes(
    x = x_value_unique,
    y = yhat,
    shape = factor(x_value_unique),
    color = type_x,
    fill = type_x
  )
) +
  facet_grid(gender_lab + year_y ~ type_x_lab) +
  geom_linerange(aes(ymin = lb, ymax = ub), linewidth = 0.9) +
  geom_point(size = 2.5, fill = "white", stroke = 1.5) +
  coord_flip() +
  format_plot(base_size = 14) +
  labs(
    x = NULL,
    y = "Percent voting",
    title = "Turnout by year and gender",
    subtitle = "Among students whose ..."
  ) +
  scale_colour_grey(start = 0.4, end = 0) +
  scale_y_continuous(
    labels = scales::label_percent(accuracy = NULL, scale = 100),
    breaks = seq(0.4, 0.8, by = 0.2)
  ) +
  scale_x_continuous(
    breaks = as.integer(names(se_plot_x_labs)),
    labels = se_plot_x_labs,
    expand = expansion(mult = 0.12, add = 0)
  ) +
  scale_shape_manual(values = se_plot_shapes) +
  theme(
    legend.position = "none",
    panel.spacing.y = unit(1.5, "lines"),
    strip.text = element_text(
      family = "Arial Narrow",
      face = "bold",
      size = rel(1.1),
      hjust = 0.5
    ),
    strip.placement = "outside",
    strip.text.y.right = element_text(angle = 0, vjust = 1.1),
    plot.title = element_text(
      family = "Arial Narrow",
      margin = margin(10, 0, 0, 0)
    ),
    plot.subtitle = element_text(
      family = "Arial Narrow",
      margin = margin(0, 0, 0, 0)
    ),
    plot.title.position = "plot"
  )



# Preview the figure on a fresh device sized like the saved files.
# windows() exists only on Windows builds of R, so other platforms fall back
# to dev.new() rather than stopping the script before the save step.
# noRStudioGD = TRUE makes dev.new() honour width/height inside RStudio.
graphics.off()
if (.Platform$OS.type == "windows") {
  windows(width = 6.5, height = 7)
} else {
  dev.new(width = 6.5, height = 7, noRStudioGD = TRUE)
}

# Explicit print(): a bare object is not auto-printed when run via source().
print(gender_se_plot)


# 8. Save =============================================

# showtext draws text at the resolution set in showtext_opts(dpi); the
# bitmap device must use the same resolution, otherwise text comes out at
# the wrong size. ggsave() defaults to dpi = 300, so pass the same value.
# NOTE(review): showtext is not attached in this file; presumably
# "Plot Preferences.r" loads it (and registers "Arial Narrow") -- confirm.
figure_dpi <- 600

# ggsave() fails if the output directory does not exist.
dir.create("Results", showWarnings = FALSE, recursive = TRUE)

showtext_begin()
showtext_opts(dpi = figure_dpi)
tryCatch(
  {
    ggsave(gender_se_plot, filename = "Results/gender_se_plot.png",
           height = 7, width = 6.5, dpi = figure_dpi)
    ggsave(gender_se_plot, filename = "Results/gender_se_plot.tiff",
           height = 7, width = 6.5, dpi = figure_dpi)
  },
  # Always switch showtext off again, even if a save fails, so later
  # devices in the session are not affected.
  finally = showtext_end()
)
